## Part of this code is adapted from: https://github.com/artemyk/ibsgd

import tensorflow as tf
import numpy as np
import os
from tensorflow.keras import backend as K
from. import utils
from tensorflow.keras import layers

def generate_data():
    """Load CIFAR-10 with pixels scaled to [0, 1] and one-hot encoded labels.

    Returns:
        Tuple ``(x_train, Y_train, x_test, Y_test)``. The image arrays are
        float32 with values in [0, 1]. The label arrays are one-hot encoded
        over 10 classes.
    """
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

    # Cast to float32 before scaling. Dividing the uint8 arrays directly by a
    # Python float would produce float64, doubling memory for no benefit:
    # Keras feeds the model float32 (its default floatx) anyway.
    x_train = x_train.astype(np.float32) / 255.0
    x_test = x_test.astype(np.float32) / 255.0

    Y_train = tf.keras.utils.to_categorical(y_train, 10)
    Y_test = tf.keras.utils.to_categorical(y_test, 10)

    return x_train, Y_train, x_test, Y_test

class getMIOutput(tf.keras.callbacks.Callback):
    """Keras callback that saves one layer's activations to ``.npy`` files.

    At the end of each selected epoch, the first ``num_selection`` training
    inputs are passed through the model. The output of layer ``Z_layer_idx``
    (an index into ``self.model.layers``) is flattened to
    ``(num_samples, -1)`` and saved as
    ``savedata/CF_epoch_{epoch}_z_{Z_layer_idx}.npy``.

    Args:
        trn: Training inputs; the first ``num_selection`` rows are recorded.
        tst: Test inputs. Stored on the instance but not read by this callback.
        Z_layer_idx: Index into ``model.layers`` of the layer to record.
        num_selection: Number of leading training samples to record.
        do_save_func: Optional predicate ``epoch -> bool``. When given,
            activations are saved only for epochs where it returns True.
    """

    def __init__(self, trn, tst, Z_layer_idx, num_selection, do_save_func=None, *kargs, **kwargs):
        super().__init__(*kargs, **kwargs)
        self.trn = trn
        self.tst = tst  # kept for API compatibility; not read here
        self.Z_layer_idx = Z_layer_idx
        self.num_selection = num_selection
        self.do_save_func = do_save_func  # controls which epochs are saved
        self.layer_values = []  # holds layer indices (mirrors layerixs)
        self.layerixs = []
        self.layerfuncs = []  # one backend function per layer: inputs -> layer output

    def on_train_begin(self, logs=None):
        """Build one backend function per model layer mapping inputs to that layer's output."""
        # Reset so that calling fit() again with the same callback does not
        # accumulate duplicate or stale entries.
        self.layerixs = []
        self.layer_values = []
        self.layerfuncs = []
        for lndx, layer in enumerate(self.model.layers):
            self.layerixs.append(lndx)
            self.layer_values.append(lndx)
            self.layerfuncs.append(K.function(self.model.inputs, [layer.output,]))

    def on_epoch_end(self, epoch, logs=None):
        """Save the selected layer's activations for ``epoch`` if it is selected.

        Raises:
            ValueError: if ``Z_layer_idx`` does not index a model layer.
        """
        if self.do_save_func is not None and not self.do_save_func(epoch):
            return

        # Validate the index up front. Otherwise a bad index surfaces as an
        # opaque reshape error on an empty array.
        if not 0 <= self.Z_layer_idx < len(self.layerfuncs):
            raise ValueError(
                f"Z_layer_idx={self.Z_layer_idx} is out of range for a model "
                f"with {len(self.layerfuncs)} layers."
            )

        inputs = self.trn[:self.num_selection]
        activity_trn = np.asarray(self.layerfuncs[self.Z_layer_idx]([inputs,])[0])

        # Flatten per-sample features. Use the actual row count so a training
        # set smaller than num_selection does not break the reshape.
        activity_array = activity_trn.reshape(activity_trn.shape[0], -1)

        # Save the numpy array to an npy file, creating the directory if needed.
        os.makedirs('savedata', exist_ok=True)
        filename = f"CF_epoch_{epoch}_z_{self.Z_layer_idx}.npy"
        filepath = os.path.join('savedata', filename)
        np.save(filepath, activity_array)

        print(f"Saved data for epoch {epoch} to {filename}")
        
        
def do_report_CF(epoch):
    """Return True if layer activity should be saved at ``epoch``.

    Saving gets sparser as training proceeds, mainly to keep long runs fast:

    * epochs below 100:    every epoch
    * epochs 100-199:      every 5th epoch
    * epochs 200-1999:     every 20th epoch
    * epochs 2000 onward:  every 100th epoch
    """
    if epoch < 100:
        return True
    if epoch < 200:
        return epoch % 5 == 0
    if epoch < 2000:
        return epoch % 20 == 0
    return epoch % 100 == 0
    
    

def _build_model():
    """Build the uncompiled CNN classifier for 32x32x3 inputs with 10 softmax outputs."""
    input_layer = layers.Input(shape=(32, 32, 3))

    # First convolutional block
    encoder_1 = layers.Conv2D(48, (3, 3), activation='relu', padding='same')(input_layer)
    encoder_1 = layers.Conv2D(48, (3, 3), activation='relu', padding='same')(encoder_1)
    BN_1 = layers.BatchNormalization()(encoder_1)
    maxpool_1 = layers.MaxPooling2D((2, 2))(BN_1)
    dropout_1 = layers.Dropout(0.5)(maxpool_1)

    # Second convolutional block. Global average pooling reads encoder_2
    # directly (there is no second max-pool). Layer order here determines the
    # model.layers indices that config["z_idx"] refers to, so adding layers on
    # this path shifts those indices.
    encoder_2 = layers.Conv2D(96, (3, 3), activation='relu', padding='same')(dropout_1)
    encoder_2 = layers.Conv2D(96, (3, 3), activation='relu', padding='same')(encoder_2)
    global_avg_pool = layers.GlobalAveragePooling2D()(encoder_2)

    # Fully connected layers
    dense_1 = layers.Dense(512, activation='relu')(global_avg_pool)
    dense_1 = layers.Dense(256, activation='relu')(dense_1)

    # Output layer
    CE_output = layers.Dense(10, activation='softmax', name='CE')(dense_1)

    return tf.keras.Model(inputs=input_layer, outputs=[CE_output])


def _make_optimizer(config):
    """Create the optimizer named by config["optimizer"] with learning rate config["lr"]."""
    name = config["optimizer"]
    if name == "SGD":
        return tf.keras.optimizers.SGD(learning_rate=config["lr"])
    if name == "Adam":
        return tf.keras.optimizers.Adam(learning_rate=config["lr"])
    # Add other optimizers here as needed.
    raise ValueError(f"Unsupported optimizer {name!r}; expected 'SGD' or 'Adam'.")


def train_model(config):
    """Train the CIFAR-10 CNN, saving activations of one layer, and print results.

    Args:
        config: Mapping with keys ``"optimizer"`` ("SGD" or "Adam"), ``"lr"``,
            ``"z_idx"`` (index into ``model.layers`` of the layer whose
            activations are saved), ``"batch_size"`` and ``"epoch"``
            (number of epochs).

    Returns:
        The Keras ``History`` object returned by ``model.fit``.

    Raises:
        ValueError: if ``config["optimizer"]`` is not a supported optimizer.
    """
    # Get data
    x_1train, Y_1train, x_1test, Y_1test = generate_data()

    # Reset Keras state and fix the seed before building layers, so that
    # weight initialization is reproducible.
    tf.keras.backend.clear_session()
    tf.random.set_seed(42)

    model = _build_model()
    opt = _make_optimizer(config)

    model.compile(optimizer=opt,
                  loss={'CE': 'categorical_crossentropy'},
                  metrics={'CE': 'accuracy'})

    reporter = getMIOutput(trn=x_1train,
                           tst=x_1test,
                           Z_layer_idx=config["z_idx"],
                           num_selection=5000,
                           do_save_func=do_report_CF)

    history = model.fit(x=x_1train, y=Y_1train,
                        batch_size=config["batch_size"],
                        epochs=config["epoch"],
                        verbose=0,
                        validation_data=(x_1test, Y_1test),
                        callbacks=[reporter,])

    # Final generalization gap, computed as (train - test) for both accuracy and loss.
    final_train_acc = history.history['accuracy'][-1]
    final_val_acc = history.history['val_accuracy'][-1]
    final_train_loss = history.history['loss'][-1]
    final_val_loss = history.history['val_loss'][-1]

    generalization_gap_acc = final_train_acc - final_val_acc
    generalization_gap_loss = final_train_loss - final_val_loss

    print(f"Final train (Accuracy): {final_train_acc}")
    print(f"Final test (Accuracy): {final_val_acc}")
    print(f"Final Generalization Gap (Accuracy): {generalization_gap_acc}")
    print(f"Final Generalization Gap (Loss): {generalization_gap_loss}")

    return history


